{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "def matrix_factorization(R, P, Q, K, steps=5000, alpha=0.0002, beta=0.02):\n",
    "    \n",
    "    Q = Q.T\n",
    "    \n",
    "    for step in xrange(steps):\n",
    "        for i in xrange(len(R)):\n",
    "            for j in xrange(len(R[i])):\n",
    "                if R[i][j] > 0:\n",
    "                    eij = R[i][j] - numpy.dot(P[i,:],Q[:,j])\n",
    "                    for k in xrange(K):\n",
    "                        P[i][k] = P[i][k] + alpha * (2 * eij * Q[k][j] - beta * P[i][k])\n",
    "                        Q[k][j] = Q[k][j] + alpha * (2 * eij * P[i][k] - beta * Q[k][j])\n",
    "        eR = numpy.dot(P,Q)\n",
    "        e = 0\n",
    "        for i in xrange(len(R)):\n",
    "            for j in xrange(len(R[i])):\n",
    "                if R[i][j] > 0:\n",
    "                    e = e + pow(R[i][j] - numpy.dot(P[i,:],Q[:,j]), 2)\n",
    "                    # normalization\n",
    "                    for k in xrange(K):\n",
    "                        e = e + (beta/2) * (pow(P[i][k],2) + pow(Q[k][j],2))\n",
    "        if step % 100 == 0:\n",
    "            print \"iteration\",step,\"error:\",e\n",
    "        if e < 0.001:\n",
    "            break\n",
    "    \n",
    "    \n",
    "    return P, Q.T\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 0 error: 82.3618339348\n",
      "iteration 100 error: 60.5750519946\n",
      "iteration 200 error: 44.8458321227\n",
      "iteration 300 error: 35.4508652198\n",
      "iteration 400 error: 29.8323886003\n",
      "iteration 500 error: 25.7903644191\n",
      "iteration 600 error: 22.2151767539\n",
      "iteration 700 error: 18.68981101\n",
      "iteration 800 error: 15.172425838\n",
      "iteration 900 error: 11.835455407\n",
      "iteration 1000 error: 8.92758960138\n",
      "iteration 1100 error: 6.63335937602\n",
      "iteration 1200 error: 4.99094459795\n",
      "iteration 1300 error: 3.9066529003\n",
      "iteration 1400 error: 3.22800909781\n",
      "iteration 1500 error: 2.81045326254\n",
      "iteration 1600 error: 2.54753015178\n",
      "iteration 1700 error: 2.37198997028\n",
      "iteration 1800 error: 2.24527470865\n",
      "iteration 1900 error: 2.14651292006\n",
      "iteration 2000 error: 2.06469906237\n",
      "iteration 2100 error: 1.99401988511\n",
      "iteration 2200 error: 1.9313156269\n",
      "iteration 2300 error: 1.87477371951\n",
      "iteration 2400 error: 1.8232751681\n",
      "iteration 2500 error: 1.77607024557\n",
      "iteration 2600 error: 1.73261627208\n",
      "iteration 2700 error: 1.69249454259\n",
      "iteration 2800 error: 1.65536613695\n",
      "iteration 2900 error: 1.62094719299\n",
      "iteration 3000 error: 1.58899422644\n",
      "iteration 3100 error: 1.55929485631\n",
      "iteration 3200 error: 1.53166158618\n",
      "iteration 3300 error: 1.50592740754\n",
      "iteration 3400 error: 1.48194254898\n",
      "iteration 3500 error: 1.45957198213\n",
      "iteration 3600 error: 1.43869344992\n",
      "iteration 3700 error: 1.41919586905\n",
      "iteration 3800 error: 1.40097800989\n",
      "iteration 3900 error: 1.38394738768\n",
      "iteration 4000 error: 1.36801931902\n",
      "iteration 4100 error: 1.3531161103\n",
      "iteration 4200 error: 1.3391663536\n",
      "iteration 4300 error: 1.32610431143\n",
      "iteration 4400 error: 1.31386937608\n",
      "iteration 4500 error: 1.30240559215\n",
      "iteration 4600 error: 1.29166123313\n",
      "iteration 4700 error: 1.28158842479\n",
      "iteration 4800 error: 1.27214280891\n",
      "iteration 4900 error: 1.26328324242\n",
      "iteration 5000 error: 1.25497152739\n",
      "iteration 5100 error: 1.24717216826\n",
      "iteration 5200 error: 1.23985215291\n",
      "iteration 5300 error: 1.23298075498\n",
      "iteration 5400 error: 1.22652935481\n",
      "iteration 5500 error: 1.22047127711\n",
      "iteration 5600 error: 1.21478164336\n",
      "iteration 5700 error: 1.20943723737\n",
      "iteration 5800 error: 1.2044163827\n",
      "iteration 5900 error: 1.19969883055\n",
      "iteration 6000 error: 1.19526565709\n",
      "iteration 6100 error: 1.19109916942\n",
      "iteration 6200 error: 1.18718281899\n",
      "iteration 6300 error: 1.18350112216\n",
      "iteration 6400 error: 1.18003958679\n",
      "iteration 6500 error: 1.17678464472\n",
      "iteration 6600 error: 1.1737235892\n",
      "iteration 6700 error: 1.17084451719\n",
      "iteration 6800 error: 1.16813627581\n",
      "iteration 6900 error: 1.16558841273\n",
      "iteration 7000 error: 1.16319113021\n",
      "iteration 7100 error: 1.1609352423\n",
      "iteration 7200 error: 1.15881213517\n",
      "iteration 7300 error: 1.1568137301\n",
      "iteration 7400 error: 1.15493244914\n",
      "iteration 7500 error: 1.15316118301\n",
      "iteration 7600 error: 1.15149326128\n",
      "iteration 7700 error: 1.14992242448\n",
      "iteration 7800 error: 1.1484427981\n",
      "iteration 7900 error: 1.14704886836\n",
      "iteration 8000 error: 1.14573545949\n",
      "iteration 8100 error: 1.14449771259\n",
      "iteration 8200 error: 1.14333106579\n",
      "iteration 8300 error: 1.14223123579\n",
      "iteration 8400 error: 1.14119420047\n",
      "iteration 8500 error: 1.14021618272\n",
      "iteration 8600 error: 1.13929363528\n",
      "iteration 8700 error: 1.13842322652\n",
      "iteration 8800 error: 1.13760182714\n",
      "iteration 8900 error: 1.13682649769\n",
      "iteration 9000 error: 1.13609447695\n",
      "iteration 9100 error: 1.13540317094\n",
      "iteration 9200 error: 1.13475014271\n",
      "iteration 9300 error: 1.13413310268\n",
      "iteration 9400 error: 1.13354989969\n",
      "iteration 9500 error: 1.13299851249\n",
      "iteration 9600 error: 1.13247704187\n",
      "iteration 9700 error: 1.13198370317\n",
      "iteration 9800 error: 1.13151681937\n",
      "iteration 9900 error: 1.13107481449\n"
     ]
    }
   ],
   "source": [
    "R = [\n",
    "    [5,3,0,1],\n",
    "    [4,0,0,1],\n",
    "    [1,1,0,5],\n",
    "    [1,0,0,4],\n",
    "    [0,1,5,4],\n",
    "]\n",
    "\n",
    "R = numpy.array(R)\n",
    "\n",
    "N = len(R)\n",
    "M = len(R[0])\n",
    "K = 3\n",
    "\n",
    "P = numpy.random.rand(N,K) # user-latent matrix\n",
    "Q = numpy.random.rand(M,K) # item-latent matrix\n",
    "\n",
    "nP, nQ = matrix_factorization(R, P, Q, K, steps=10000)\n",
    "nR = numpy.dot(nP, nQ.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5, 3, 0, 1],\n",
       "       [4, 0, 0, 1],\n",
       "       [1, 1, 0, 5],\n",
       "       [1, 0, 0, 4],\n",
       "       [0, 1, 5, 4]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4.97813151,  2.97773525,  4.57866262,  1.00164119],\n",
       "       [ 3.97965774,  2.36973846,  3.84246187,  0.9997563 ],\n",
       "       [ 1.0219862 ,  0.94643827,  5.69738167,  4.96918894],\n",
       "       [ 0.98939012,  0.82642272,  4.66483239,  3.97985017],\n",
       "       [ 1.35846187,  1.06372278,  4.96619056,  4.00240827]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-0.25827397,  2.38136068],\n",
       "        [-0.11245036,  1.85982849],\n",
       "        [ 1.82778173,  0.81810458],\n",
       "        [ 1.45250616,  0.68658217],\n",
       "        [ 1.25475026,  1.45701111]]), array([[-0.33123109,  2.0881403 ],\n",
       "        [-0.14150625,  1.144348  ],\n",
       "        [ 1.13011239,  2.3552181 ],\n",
       "        [ 2.41435452,  0.68244396]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nP,nQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 3.92082987,  2.14420344,  4.25322017,  0.99773368])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "numpy.dot(P[1,:],Q.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
